import torch

from .naive_self_attention import NaiveAttention
from useful_tools import create_module

class create_NaiveAttention(create_module):
    """Factory that configures and builds a NaiveAttention module."""
    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_heads,
        precision=torch.float32,
        dropout=0.1,
        require_bias=False,
        kv_needed=True,
    ):
        super().__init__(NaiveAttention)
        # Keep the hyperparameters so they can be forwarded to NaiveAttention.
        self.module_kwargs = dict(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_heads=num_heads,
            precision=precision,
            dropout=dropout,
            require_bias=require_bias,
            kv_needed=kv_needed,
        )
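
    # create_module's interface isn't defined in this file, so the override
    # below is a minimal sketch: it assumes the factory is meant to build the
    # module when called, and that NaiveAttention accepts exactly the keyword
    # arguments collected above. Drop it if the base class already does this.
    def __call__(self):
        return NaiveAttention(**self.module_kwargs)
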
if __name__ == '__main__':
    # Note: the relative import above means this file must be run as a module
    # (python -m <package>.<this module>), not as a plain script.
    model = create_NaiveAttention(
        input_dim=8,
        hidden_dim=3,
        output_dim=8,
        num_heads=1,
    )
    print(model())
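
    # Hedged smoke test: assumes NaiveAttention is a torch.nn.Module whose
    # forward performs self-attention on a single (batch, seq_len, input_dim)
    # tensor; adjust the call if its forward signature differs.
    attn = model()
    x = torch.randn(2, 4, 8)
    print(attn(x).shape)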